from .patch_utils import GeesPatchesManager
from .wrappers.gpt_model_wrapper import get_num_layers_to_build_by_plan
from ..get_data import DataLoader1
from ..pipeline.models.model_llama import get_llama_model


def patch():
    """Register this package's replacement objects and apply them.

    Each entry in ``replacements`` pairs a dotted target path with the object
    handed to ``GeesPatchesManager.register_patch`` for that path. Entries are
    registered in table order. ``GeesPatchesManager.apply_patches()`` is called
    once, after every entry has been registered. What "applying" means (for
    example, rebinding the attribute at the dotted path) is defined in
    ``patch_utils``; presumably only lookups made through that dotted path
    after application see the replacement. Confirm against ``patch_utils``.

    Currently disabled patches, kept here as a record:
      * ``megatron.core.transformer.transformer_block.get_num_layers_to_build``
        -> ``get_num_layers_to_build_by_plan``
      * ``megatron.training.initialize.initialize_megatron``. The earlier
        draft of this line gave no replacement and misspelled the module as
        ``iniitialize``.
    """
    replacements = (
        ('torch.utils.data.DataLoader', DataLoader1),
        ('transformers.models.llama.modeling_llama.LlamaForCausalLM', get_llama_model),
    )
    for target_path, replacement in replacements:
        GeesPatchesManager.register_patch(target_path, replacement)
    # Apply only after everything is registered, matching the original ordering.
    GeesPatchesManager.apply_patches()
 
# Module-level call: importing this module registers and applies the patches
# defined in patch() as an import-time side effect.
# NOTE(review): modules that bound e.g. `DataLoader` via `from ... import`
# before this import may keep the original object. This depends on how
# GeesPatchesManager applies patches; confirm the intended import order.
patch()